import os
from prompt_toolkit import PromptSession
from prompt_toolkit import print_formatted_text
from prompt_toolkit.formatted_text import FormattedText
from rich.console import Console, Group
from rich.panel import Panel
from rich.text import Text
from rich.markdown import Markdown
from rich.live import Live
import json
from rich.console import group
from prompt_toolkit import HTML
from rich.theme import Theme

@group()
def get_panels():
    """Produce two sample panels ("Hello" on blue, "World" on red) as one Rich group."""
    for label, background in (("Hello", "on blue"), ("World", "on red")):
        yield Panel(label, style=background)

class ChatAgent:
    """Interactive terminal chat that streams assistant replies and persists the transcript.

    The conversation (user and assistant turns; the system prompt is not stored) is
    kept in ``self.history`` and rewritten as JSON to ``<data_dir>/chat_<task_id>.json``
    after every turn.

    Args:
        client: Object whose ``generate(messages=...)`` returns an iterable of text
            chunks for a list of ``{"role": ..., "content": ...}`` dicts.
        data_dir: Directory holding history files; created if it does not exist.
        task_id: Identifier embedded in the history file name (``None`` gives
            ``chat_None.json``).
    """

    # Rich style names overridden with empty styles in the console theme,
    # presumably so replies render as plain text instead of Rich's default colours.
    _PLAIN_STYLE_NAMES = (
        "markdown",
        "markdown.heading",
        "markdown.code",
        "markdown.pre",
        "markdown.link",
        "markdown.list",
        "markdown.strong",
        "markdown.emphasis",
        "markdown.block_quote",
        "repr.number",
        "repr.string",
        "repr.string_quote",
        "markdown.table",
        "markdown.table.header",
        "markdown.table.row",
        "markdown.table.cell",
        "markdown.hrule",
    )

    def __init__(self, client, data_dir, task_id=None):
        self.client = client
        self.data_dir = data_dir
        self.history = []
        self.console = Console(theme=Theme({name: "" for name in self._PLAIN_STYLE_NAMES}))
        self.task_id = task_id
        # Created on first prompt by _get_prompt_session() and reused afterwards.
        self._prompt_session = None
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(data_dir, exist_ok=True)

    def get_history_file_path(self):
        """Return the path of this task's history file inside ``data_dir``."""
        return os.path.join(self.data_dir, f"chat_{self.task_id}.json")

    def save_history(self, history):
        """Write ``history`` as indented JSON to the history file.

        The data goes to a temporary sibling file first and is then moved into
        place with ``os.replace``, so an interrupted write leaves the previous
        history intact instead of a truncated file.
        """
        history_file = self.get_history_file_path()
        tmp_file = f"{history_file}.tmp"
        with open(tmp_file, "w", encoding="utf-8") as file:
            json.dump(history, file, indent=4)
        os.replace(tmp_file, history_file)

    def start(self):
        """
        Start the main chat session loop. Continuously gets user input, generates responses, and updates the history.

        The loop ends when get_user_input() returns None (``exit``, Ctrl-C on an
        empty prompt, Ctrl-D, or an input error).
        """
        # Separator line, usage hint, and greeting.
        self.console.print(Text("────────────────────────────────────────────────────────────────\n\n", style="magenta"), end="")
        self.console.print(Text("Enter your message (single line or use ''' for multiple lines).\n", style="blue"))

        greeting = Group(
            Text("Pentest Muse: ", style="bold green", end=""),
            Text("Hi! How can I help you today?"),
        )
        self.console.print(greeting)

        while self.get_user_input() is not None:
            self.generate_response()

    def get_user_input(self):
        """
        Prompt for one user message, append it to the history, save, and return it.

        Input rules: a single non-blank line is sent as-is; a line containing only
        ``'''`` opens/closes a multiline block (blank lines inside a block are kept);
        ``exit`` quits. Ctrl-C discards a partially typed block, or quits when
        nothing is buffered. Ctrl-D quits.

        Returns:
            The message text, or None when the chat should end (``exit``, Ctrl-C on
            an empty buffer, Ctrl-D, or an unexpected error, which is printed).
        """
        session = self._get_prompt_session()

        self.console.print(Text("\n"), end="")

        def read_line(multiline):
            # Show the "> " prompt only for the first line of a message.
            return session.prompt("" if multiline else HTML('> '))

        try:
            user_input = self._collect_message(read_line)
            if user_input is None:
                return None
            self.history.append({"role": "user", "content": user_input})
            self.save_history(self.history)
            return user_input
        except (EOFError, KeyboardInterrupt):
            # prompt_toolkit raises EOFError on Ctrl-D; treat it (and a Ctrl-C that
            # lands outside the prompt) as a request to leave, not as an error.
            return None
        except Exception as e:
            print_formatted_text(FormattedText([("red", f"Error occurred while getting user input: {e}")]))
            return None

    def _get_prompt_session(self):
        """Return this agent's PromptSession, creating it on first use.

        One session is reused for every prompt so prompt_toolkit's in-memory
        input history survives between messages (up-arrow recall).
        """
        if getattr(self, "_prompt_session", None) is None:
            self._prompt_session = PromptSession()
        return self._prompt_session

    @staticmethod
    def _collect_message(read_line):
        """Assemble one message from successive ``read_line(multiline)`` calls.

        ``read_line`` receives whether a ``'''`` block is currently open and
        returns the next raw input line. KeyboardInterrupt raised by it restarts
        the message if some block lines were already typed, otherwise ends input.

        Returns:
            The message (lines joined with "\\n"), or None if the user typed
            ``exit`` or pressed Ctrl-C with nothing buffered.
        """
        lines = []
        multiline = False
        while True:
            try:
                line = read_line(multiline)
            except KeyboardInterrupt:
                if not lines:
                    return None
                # Discard the partially typed block and start over.
                lines, multiline = [], False
                continue

            stripped = line.strip()
            if stripped == "'''":
                multiline = not multiline
            elif stripped == "exit":
                return None
            elif not stripped and not multiline:
                # Ignore empty submissions; blank lines inside a block are content.
                continue
            else:
                lines.append(line)

            if not multiline:
                return "\n".join(lines)

    @staticmethod
    def _reply_view(body):
        """Build the "Pentest Muse" header followed by a green panel wrapping ``body``."""
        return Group(
            Text("\nPentest Muse: \n", style="bold green", end=""),
            Panel(body, expand=False, border_style="green"),
        )

    def generate_response(self):
        """
        Generate a response based on the conversation history, updating the output live as the response is received.

        The system prompt is prepended for the request only. Ctrl-C while
        streaming stops early and the partial reply is still recorded; any other
        exception from the client propagates without recording an assistant turn.
        """
        from app.prompts.prompts import CHAT_AGENT_PROMPT

        messages = [{"role": "system", "content": CHAT_AGENT_PROMPT}] + self.history

        full_message_content = ''

        with Live(console=self.console, refresh_per_second=10) as live:
            try:
                live.update(self._reply_view('Thinking...'))

                # Stream the completion from the client.
                for chunk in self.client.generate(messages=messages):
                    full_message_content += chunk
                    # Re-render the whole accumulated text each time, not just the new chunk.
                    live.update(self._reply_view(Markdown(full_message_content)))
            except KeyboardInterrupt:
                pass  # Streaming stopped by the user; keep what arrived so far.

        self.history.append({"role": "assistant", "content": full_message_content})
        self.save_history(self.history)
